import cv2
import numpy as np
import xml.etree.ElementTree as ET
import os.path as osp
import os
# Custom image preprocessing helper (reads and normalizes a 16-bit TIFF)
from read_tiff import tiff_16bit_img_read_and_normalize as imgread

# Value written into the <folder> element of every generated VOC annotation.
FOLDER = 'VOC2012'

def intersection_over_origin(target, origin):
    """Return the overlap area of two boxes divided by the area of ``origin``.

    Args:
        target: (xmin, ymin, xmax, ymax) of the first box.
        origin: (xmin, ymin, xmax, ymax) of the reference box.
    Returns:
        Float ratio in [0, 1]: intersection area over the area of ``origin``.
    """
    # Corners of the intersection rectangle (inclusive pixel coordinates).
    left = max(target[0], origin[0])
    top = max(target[1], origin[1])
    right = min(target[2], origin[2])
    bottom = min(target[3], origin[3])

    # Width/height clamp to zero when the boxes do not overlap at all.
    overlap_area = max(0, right - left + 1) * max(0, bottom - top + 1)
    origin_area = (origin[2] - origin[0] + 1) * (origin[3] - origin[1] + 1)

    return overlap_area / float(origin_area)

def voc_basic_xml_element_generator(filename, folder, size_w, size_h, size_c):
    '''
    description: build the shared <folder> and <size> elements of a VOC label
    param:
        filename: kept for signature compatibility; not used here — the
                  <filename> element is created and filled by the caller
        folder: text for the <folder> element
        size_w, size_h, size_c: image width / height / channel count
    return:
        (folder_element, size_element)
    '''
    folder_ele = ET.Element('folder')
    folder_ele.text = folder

    size_ele = ET.Element('size')
    for tag, value in (('width', size_w), ('height', size_h), ('depth', size_c)):
        child = ET.SubElement(size_ele, tag)
        child.text = str(value)

    return (folder_ele, size_ele)

def voc_object_xml_element_resolver(object_element):
    '''
    description: convert the <object> element of a VOC xml into a dict
    param:
        object_element: the <object> Element to convert
    return:
        dict with keys name/pose/truncated/difficult (str or None) and
        'bndbox' (dict of xmin/ymin/xmax/ymax strings)
    '''
    bndbox = object_element.find('bndbox')
    fields = {tag: object_element.findtext(tag)
              for tag in ('name', 'pose', 'truncated', 'difficult')}
    fields['bndbox'] = {tag: bndbox.findtext(tag)
                        for tag in ('xmin', 'ymin', 'xmax', 'ymax')}
    return fields


def voc_object_xml_element_generator(object_dict):
    '''
    description: convert the dict describing a VOC object back into an
                 <object> XML element
    param:
        object_dict: dict as produced by voc_object_xml_element_resolver
    return:
        the converted <object> Element
    '''
    obj = ET.Element('object')
    for key, value in object_dict.items():
        child = ET.SubElement(obj, key)
        if key == 'bndbox':
            # Nested coordinate dict becomes child elements of <bndbox>.
            for coord_tag, coord_value in value.items():
                coord = ET.SubElement(child, coord_tag)
                coord.text = str(coord_value)
        else:
            child.text = value
    return obj
        
    

def voc_subimg_box_generator(input_w, input_h, subimg_w, subimg_h, overlap_w, overlap_h):
    '''
    description: generate pixel locations of the sub-image tiles covering
                 the source image, left-to-right then top-to-bottom
    param:
        input_w, input_h: size of the source image
        subimg_w, subimg_h: size of each tile
        overlap_w, overlap_h: overlap between neighbouring tiles
    return:
        ((steps_w, steps_h), boxes): tile counts along each axis and a
        list of (xmin, ymin, xmax, ymax) tuples
    raises:
        ValueError: if an overlap is >= the tile size (the stride would be
                    non-positive and the loop would never terminate)
    '''
    offset_w = subimg_w - overlap_w
    offset_h = subimg_h - overlap_h
    # Guard against an infinite loop: a non-positive stride never advances.
    if offset_w <= 0 or offset_h <= 0:
        raise ValueError('overlap must be smaller than the sub-image size')
    # TODO: also crop the leftover border area that does not fit a full tile
    boxes = []
    step_w, step_h = (0, 0)
    while step_h * offset_h + subimg_h <= input_h:
        step_w = 0
        while step_w * offset_w + subimg_w <= input_w:
            x0 = step_w * offset_w
            y0 = step_h * offset_h
            # (xmin, ymin, xmax, ymax)
            boxes.append((x0, y0, x0 + subimg_w, y0 + subimg_h))
            step_w += 1
        step_h += 1
    return ((step_w, step_h), boxes)

def voc_label_preprocess(input_anno):
    '''
    description: split an annotation file into the carried-over info
                 elements and the annotated objects
    param:
        input_anno: path (or file object) of the VOC annotation xml
    return:
        (anno_info, anno_obj): elements to copy verbatim into every
        sub-annotation, and the original <object> elements
    '''
    root = ET.parse(input_anno).getroot()
    # These tags are regenerated per sub-image, so they are not carried over.
    regenerated = ('filename', 'size', 'path', 'folder')
    anno_info, anno_obj = [], []
    for element in root:
        if element.tag == 'object':
            anno_obj.append(element)
        elif element.tag not in regenerated:
            anno_info.append(element)
    return (anno_info, anno_obj)

def voc_split_annotation(anno_info, anno_basic, anno_obj, split_boxes, ioo_thre=0.5):
    '''
    description: build one annotation tree per split box, keeping only the
                 objects that overlap the box by at least ioo_thre
    param:
        anno_info: elements copied verbatim into every sub-annotation
        anno_basic: shared <folder>/<size> elements
                    (see voc_basic_xml_element_generator)
        anno_obj: original <object> elements
        split_boxes: list of (xmin, ymin, xmax, ymax) sub-image boxes
        ioo_thre: minimum intersection-over-origin ratio to keep an object
    return:
        list of ElementTree, one per split box; each contains an empty
        <filename> element whose text is filled in later by the caller
    '''
    subannotation = []
    for box in split_boxes:
        xmin, ymin, xmax, ymax = box
        root = ET.Element('annotation')
        tree = ET.ElementTree(root)
        # NOTE: the same info/basic Element objects are shared by every
        # sub-annotation tree; ElementTree tolerates this for serialization.
        for info in anno_info:
            root.append(info)
        root.append(ET.Element('filename'))  # text filled in by the caller
        for basic in anno_basic:
            root.append(basic)
        for obj in anno_obj:
            dict_obj = voc_object_xml_element_resolver(obj)
            ob_xmin = int(dict_obj['bndbox']['xmin'])
            ob_ymin = int(dict_obj['bndbox']['ymin'])
            ob_xmax = int(dict_obj['bndbox']['xmax'])
            ob_ymax = int(dict_obj['bndbox']['ymax'])
            ioo = intersection_over_origin((xmin, ymin, xmax, ymax),
                                           (ob_xmin, ob_ymin, ob_xmax, ob_ymax))
            # Decide whether to drop the object or mark it as truncated.
            if ioo < ioo_thre:
                # Too little of the object falls inside this tile.
                continue
            elif (1 - ioo) > 1e-5:
                # Object is cut by the tile border -> mark as truncated.
                dict_obj['truncated'] = str(1)
            elif dict_obj['truncated'] is None:
                # Fully contained and the source had no truncated flag.
                dict_obj['truncated'] = str(0)
            # Shift coordinates into the tile frame, clamped to its border.
            # TODO(review): should the lower clamp be 1 or 0?
            dict_obj['bndbox']['xmin'] = max(ob_xmin - xmin, 1)
            dict_obj['bndbox']['ymin'] = max(ob_ymin - ymin, 1)
            dict_obj['bndbox']['xmax'] = min(ob_xmax - xmin, xmax - xmin - 1)
            dict_obj['bndbox']['ymax'] = min(ob_ymax - ymin, ymax - ymin - 1)
            if dict_obj['difficult'] is None:
                dict_obj['difficult'] = str(0)
            root.append(voc_object_xml_element_generator(dict_obj))
        subannotation.append(tree)
    return subannotation


def voc_split_image_and_annotation(fun_imread, input_pic, input_anno, output_dir, width, height, overlap_w=0, overlap_h=0, ioo_thresh=0.5, ignore_empty=False):
    '''
    description: split one image and its VOC annotation into tiles and write
                 the results under output_dir/JPEGImages and
                 output_dir/Annotations
    param:
        fun_imread: callable(path) -> HxWxC image array
        input_pic: path of the source image
        input_anno: path of the source VOC annotation xml
        output_dir: output root folder (created if missing)
        width, height: tile size
        overlap_w, overlap_h: overlap between neighbouring tiles
        ioo_thresh: minimum intersection-over-origin to keep an object
        ignore_empty: drop tiles that contain no annotated object
    return:
        None (writes image and xml files as a side effect)
    '''
    im = fun_imread(input_pic)
    # Base name without its extension. osp.splitext is robust to dots
    # inside the name and to names with no extension (the old manual
    # split crashed with IndexError on extension-less paths).
    im_name = osp.splitext(osp.basename(input_pic))[0]
    # Make sure the full VOC folder layout exists before writing into it
    # (the writes below target these subdirectories).
    os.makedirs(osp.join(output_dir, 'JPEGImages'), exist_ok=True)
    os.makedirs(osp.join(output_dir, 'Annotations'), exist_ok=True)
    # Image dimensions.
    im_h, im_w, im_c = im.shape
    # Tile locations over the source image.
    steps, boxes = voc_subimg_box_generator(im_w, im_h, width, height, overlap_w, overlap_h)
    # Shared <folder>/<size> elements for every tile annotation.
    anno_basic = voc_basic_xml_element_generator(im_name, FOLDER, width, height, im_c)
    # Elements carried over from the original annotation plus its objects.
    anno_info, anno_obj = voc_label_preprocess(input_anno)
    # Split the annotation per tile (previously ioo_thresh was accepted
    # but silently ignored; it is now forwarded).
    sub_anno = voc_split_annotation(anno_info, anno_basic, anno_obj, boxes, ioo_thre=ioo_thresh)
    if ignore_empty:
        # Keep only tiles whose annotation contains at least one object.
        kept = [(b, a) for b, a in zip(boxes, sub_anno) if a.find('object') is not None]
        boxes = [b for b, _ in kept]
        sub_anno = [a for _, a in kept]
    # Crop and save each tile image.
    img = np.asarray(im)
    for k, v in enumerate(boxes):
        subimg = img[v[1]: v[3], v[0]: v[2], :]
        cv2.imwrite(osp.join(output_dir, 'JPEGImages', im_name + '_{:04d}'.format(k) + '.jpg'), subimg)
    # Save each tile annotation, pointing its <filename> at the tile image.
    for k, v in enumerate(sub_anno):
        v.find('filename').text = im_name + '_{:04d}'.format(k) + '.jpg'
        v.write(osp.join(output_dir, 'Annotations', im_name + '_{:04d}'.format(k) + '.xml'),
                encoding='utf-8', xml_declaration=True)


if __name__ == '__main__':
    # Example run: tile one SARShip scene into 3000x3000 crops, no overlap,
    # skipping tiles that contain no annotated object.
    IMG = 'SARShip-1.0-28/SARShip-1.0-28.tiff'
    ANNO = 'SARShip-1.0-28/SARShip-1.0-28-label.xml'
    OUTPUT_DIR = './saved4'

    # Uses the custom 16-bit tiff reader imported as `imgread`.
    voc_split_image_and_annotation(imgread, IMG, ANNO, OUTPUT_DIR, 3000, 3000, 0, 0, ignore_empty=True)
